
Conversation

@tianyu-l (Contributor) commented Aug 6, 2025

Given the complexity of the MoE and EP modules, this PR:

  1. creates torchtitan/models/moe.py as the central MoE implementation (this is similar to why we have torchtitan/models/attention.py)
  2. creates torchtitan/distributed/expert_parallel.py as the central EP implementation
  3. renames torchtitan/distributed/pipeline.py -> torchtitan/distributed/pipeline_parallel.py to be consistent with EP
  4. applies a temporary fix by @rakkit for possible memory leaking of DP2EP with recompute (#1467), until the memory leak issue with AC + PT-D all_to_all_single_autograd is fixed (cc @soulitzer) -- a sketch of this kind of workaround is shown below
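For reference, a minimal sketch (not the exact PR code) of the kind of temporary workaround item 4 refers to: a hand-written autograd.Function around torch.distributed.all_to_all_single, used in place of the functional-collectives all_to_all_single_autograd until the AC memory leak is resolved. Apart from the forward signature visible in the review diff further down, all names and details here are illustrative.

    import torch
    import torch.distributed as dist

    class _AllToAllSingle(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, out_splits, in_splits, group):
            if isinstance(out_splits, torch.Tensor):
                out_splits = out_splits.tolist()  # d2h sync, see the review thread
            if isinstance(in_splits, torch.Tensor):
                in_splits = in_splits.tolist()
            ctx.group = group
            ctx.out_splits = out_splits
            ctx.in_splits = in_splits
            out = x.new_empty((sum(out_splits),) + x.shape[1:])
            dist.all_to_all_single(out, x.contiguous(), out_splits, in_splits, group=group)
            return out

        @staticmethod
        def backward(ctx, grad_out):
            # the backward of an all-to-all is an all-to-all with the splits swapped
            grad_in = grad_out.new_empty((sum(ctx.in_splits),) + grad_out.shape[1:])
            dist.all_to_all_single(
                grad_in, grad_out.contiguous(), ctx.in_splits, ctx.out_splits, group=ctx.group
            )
            return grad_in, None, None, None

    # usage (ep_group being an initialized process group):
    # routed = _AllToAllSingle.apply(x, out_splits, in_splits, ep_group)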

@meta-cla bot added the CLA Signed label Aug 6, 2025
@tianyu-l tianyu-l linked an issue (Circular imports) Aug 6, 2025 that may be closed by this pull request
@danielvegamyhre (Contributor) left a comment:

lgtm, left a couple of minor comments/questions.

    import torch.nn as nn
    from torch.distributed._functional_collectives import all_to_all_single_autograd

    # from torch.distributed._functional_collectives import all_to_all_single_autograd
nit: remove commented code

@tianyu-l (Contributor, Author) Aug 6, 2025:

This is intentional -- we should restore this implementation after the bug is fixed. I reorganized the code a bit to make it clearer.

    @staticmethod
    def forward(ctx, x, out_splits, in_splits, group):
        if isinstance(out_splits, torch.Tensor):
            out_splits = out_splits.tolist()
won't tolist() cause d2h sync? is this okay / intentional in this case?

@tianyu-l (Contributor, Author) Aug 6, 2025:

It will. This is a temporary fix, but currently in EP there are multiple places with d2h sync. I'm working on another implementation to kill them.
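For context, a minimal standalone illustration (not PR code, and assuming a CUDA device is available) of why .tolist() on a GPU tensor is a device-to-host sync point:

    import torch

    splits = torch.tensor([4, 2, 6], device="cuda")  # typically produced by GPU work
    # .tolist() copies the tensor to the host and blocks the CPU until the GPU
    # work producing `splits` has finished -- a d2h synchronization.
    out_splits = splits.tolist()
    print(out_splits)  # [4, 2, 6]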

@ruisizhang123 (Contributor) left a comment:
LGTM, thank you for the refactor.

@tianyu-l tianyu-l force-pushed the cleanup branch 2 times, most recently from 85dc2ad to 16ad9f5 on August 6, 2025 02:44

    @dataclass
    class MoEArgs:
        moe_enabled: bool = True
Why do we need to have moe_enabled in MoEArgs?

I didn't see anywhere that this parameter is set to false.

@tianyu-l (Contributor, Author):
makes sense, removed
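Purely for illustration (the field names below are hypothetical placeholders, not necessarily torchtitan's), one way the dataclass could look after dropping the always-true flag, with "is this layer MoE?" signaled by whether a MoEArgs instance is supplied at all:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class MoEArgs:
        num_experts: int = 8   # hypothetical field
        top_k: int = 1         # hypothetical field

    @dataclass
    class ModelArgs:           # hypothetical container
        dim: int = 4096
        moe_args: Optional[MoEArgs] = None  # None => dense FFN, otherwise MoE layer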

@wwwjn (Contributor) left a comment:
LGTM! Nice refactor

@rakkit (Contributor) commented Aug 6, 2025

(sry, I am leaving for vacation and too lazy to open a PR.)
Since you are refactoring MoE, here is another trick/bug we found. When we use DP2EP, the gradient-reduction denominator of EP is always smaller than that of other modules (e.g. attention or embedding), so the actual gradients of EP are always larger. See the logs below.

[screenshot: gradient-norm logs]

We fixed this by adding a loss_average_denominator in parallel_dims.py and forcing EP's reduce denominator via

            transformer_block.feed_forward.experts.set_reduce_scatter_divide_factor(
                loss_average_denominator,
            )

in apply_fsdp; you can also check the full code here

where loss_average_denominator = dp_replicate * dp_shard * cp (see here).
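Sketched end to end (anything not named in the comment above -- e.g. the parallel_dims attributes and the transformer_block variable -- is a placeholder, not necessarily the real code), the fix amounts to:

    # inside apply_fsdp, after fully_shard has wrapped the experts module
    loss_average_denominator = (
        parallel_dims.dp_replicate * parallel_dims.dp_shard * parallel_dims.cp
    )
    transformer_block.feed_forward.experts.set_reduce_scatter_divide_factor(
        loss_average_denominator,
    )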


We also have a better version of the bias update that only needs one reduce; you can check the code here. One gotcha that does not affect the bias update itself, but is important to know once we have activation checkpointing:

        if self.load_balance_coeff is not None:
            with torch.no_grad():
                self.tokens_per_expert.add_(num_tokens_per_expert)

This snippet will be called more than once (the forward runs again during recompute), so the accumulated value of num_tokens_per_expert ends up 2x what it should be.
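A tiny standalone repro of that effect (not torchtitan code): a side-effecting counter inside a checkpointed function gets bumped both in the original forward and in the recompute triggered by backward.

    import torch
    from torch.utils.checkpoint import checkpoint

    counter = torch.zeros(1)

    def ffn(x):
        # stand-in for the tokens_per_expert accumulation inside the MoE layer
        with torch.no_grad():
            counter.add_(1)
        return x * 2

    x = torch.randn(4, requires_grad=True)
    y = checkpoint(ffn, x, use_reentrant=False)
    y.sum().backward()
    print(counter)  # tensor([2.]) -- the update ran twice, doubling the stats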

@tianyu-l tianyu-l merged commit a9aa506 into main Aug 6, 2025
7 checks passed
@tianyu-l tianyu-l deleted the cleanup branch August 6, 2025 05:47
@tianyu-l tianyu-l mentioned this pull request Aug 6, 2025
joellidin pushed a commit to one-covenant/torchtitan that referenced this pull request Aug 8, 2025
joellidin pushed a commit to one-covenant/torchtitan that referenced this pull request Aug 8, 2025
tianyu-l added a commit that referenced this pull request Aug 11, 2025
@garrett361 (Contributor) commented:

Nice @rakkit, we found the same issue with the ep grads being off by a factor. I was finding that set_reduce_scatter_divide_factor errored when using an mp policy, though.

Surprised you didn't hit that? Think I saw you're on torch==2.6 in another comment elsewhere

@rakkit (Contributor) commented Aug 13, 2025

lol @garrett361 thanks for the info. I did not see the issue on either Torch 2.6 or 2.7.1.

To clarify, I only tested the default mp setting (mixed_precision_param=bf16 and mixed_precision_reduce=fp32).
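For reference, a minimal sketch of that default setting expressed with FSDP2's MixedPrecisionPolicy (the surrounding fully_shard wiring and the experts module are omitted placeholders, not the exact torchtitan code):

    import torch
    from torch.distributed.fsdp import MixedPrecisionPolicy

    # bf16 parameters for compute, fp32 for the gradient reduce-scatter
    mp_policy = MixedPrecisionPolicy(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.float32,
    )
    # this policy would be passed as fully_shard(module, mp_policy=mp_policy)
    # before calling set_reduce_scatter_divide_factor on the sharded experts,
    # which is the combination reported above to error on some setups.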
